#include "gdsm3.h"
#include <string.h>
#include <stdio.h>

#define FF0(x,y,z)                      ((x) ^ (y) ^ (z)) 
#define FF1(x,y,z)                      (((x) & (y)) | ( (x) & (z)) | ( (y) & (z)))
#define GG0(x,y,z)                      ((x) ^ (y) ^ (z)) 
#define GG1(x,y,z)                      (((x) & (y)) | ( (~(x)) & (z)))
#define SHL(x,n)                        (((x) & 0xFFFFFFFF) << n)
#define ROTL(x,n)                       (SHL((x),n) | ((x) >> (32 - n)))
#define P0(x)                           ((x) ^  ROTL((x),9) ^ ROTL((x),17)) 
#define P1(x)                           ((x) ^  ROTL((x),15) ^ ROTL((x),23)) 

#ifndef GET_ULONG_BE
#define GET_ULONG_BE(n, b, i)                       \
{                                                   \
    (n) = ((unsigned int)(b)[(i)] << 24)           \
        | ((unsigned int)(b)[(i) + 1] << 16)       \
        | ((unsigned int)(b)[(i) + 2] << 8)        \
        | ((unsigned int)(b)[(i) + 3]);            \
}
#endif

#ifndef PUT_ULONG_BE
#define PUT_ULONG_BE(n, b, i)                       \
{                                                   \
    (b)[(i)] = (unsigned char)((n) >> 24);          \
    (b)[(i) + 1] = (unsigned char)((n) >> 16);      \
    (b)[(i) + 2] = (unsigned char)((n) >> 8);       \
    (b)[(i) + 3] = (unsigned char)((n));            \
}
#endif

typedef struct
{
    unsigned int total[2];
    unsigned int state[8];
    unsigned char buffer[64];
} sm3_context;

static const unsigned char sm3_padding[64] =
{
 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
};

static void _sm3_init(sm3_context* ctx)
{
    ctx->total[0] = 0;
    ctx->total[1] = 0;

    ctx->state[0] = 0x7380166F;
    ctx->state[1] = 0x4914B2B9;
    ctx->state[2] = 0x172442D7;
    ctx->state[3] = 0xDA8A0600;
    ctx->state[4] = 0xA96F30BC;
    ctx->state[5] = 0x163138AA;
    ctx->state[6] = 0xE38DEE4D;
    ctx->state[7] = 0xB0FB0E4E;
}

static void _sm3_process(sm3_context* ctx, const unsigned char data[64])
{
    unsigned int SS1, SS2, TT1, TT2, W[68], W1[64];
    unsigned int A, B, C, D, E, F, G, H;
    unsigned int T[64];
    unsigned int Temp1, Temp2, Temp3, Temp4, Temp5;
    int j;

    for (j = 0; j < 16; j++)
    {
        T[j] = 0x79CC4519;
    }

    for (j =16; j < 64; j++)
    {
        T[j] = 0x7A879D8A;
    }

    GET_ULONG_BE(W[0], data, 0);
    GET_ULONG_BE(W[1], data, 4);
    GET_ULONG_BE(W[2], data, 8);
    GET_ULONG_BE(W[3], data, 12);
    GET_ULONG_BE(W[4], data, 16);
    GET_ULONG_BE(W[5], data, 20);
    GET_ULONG_BE(W[6], data, 24);
    GET_ULONG_BE(W[7], data, 28);
    GET_ULONG_BE(W[8], data, 32);
    GET_ULONG_BE(W[9], data, 36);
    GET_ULONG_BE(W[10], data, 40);
    GET_ULONG_BE(W[11], data, 44);
    GET_ULONG_BE(W[12], data, 48);
    GET_ULONG_BE(W[13], data, 52);
    GET_ULONG_BE(W[14], data, 56);
    GET_ULONG_BE(W[15], data, 60);

    for (j = 16; j < 68; j++)
    {
        Temp1 = W[j - 16] ^ W[j - 9];
        Temp2 = ROTL(W[j - 3], 15);
        Temp3 = Temp1 ^ Temp2;
        Temp4 = P1(Temp3);
        Temp5 =  ROTL(W[j - 13], 7) ^ W[j - 6];
        W[j] = Temp4 ^ Temp5;
    }

    for (j = 0; j < 64; j++)
    {
        W1[j] = W[j] ^ W[j + 4];
    }

    A = ctx->state[0];
    B = ctx->state[1];
    C = ctx->state[2];
    D = ctx->state[3];
    E = ctx->state[4];
    F = ctx->state[5];
    G = ctx->state[6];
    H = ctx->state[7];

    for (j = 0; j < 16; j++)
    {
        SS1 = ROTL((ROTL(A, 12) + E + ROTL(T[j], j)), 7); 
        SS2 = SS1 ^ ROTL(A, 12);
        TT1 = FF0(A, B, C) + D + SS2 + W1[j];
        TT2 = GG0(E, F, G) + H + SS1 + W[j];
        D = C;
        C = ROTL(B, 9);
        B = A;
        A = TT1;
        H = G;
        G = ROTL(F, 19);
        F = E;
        E = P0(TT2);
    }
    
    for (j = 16; j < 64; j++)
    {
        SS1 = ROTL((ROTL(A, 12) + E + ROTL(T[j], j)), 7); 
        SS2 = SS1 ^ ROTL(A, 12);
        TT1 = FF1(A, B, C) + D + SS2 + W1[j];
        TT2 = GG1(E, F, G) + H + SS1 + W[j];
        D = C;
        C = ROTL(B, 9);
        B = A;
        A = TT1;
        H = G;
        G = ROTL(F, 19);
        F = E;
        E = P0(TT2);
    }

    ctx->state[0] ^= A;
    ctx->state[1] ^= B;
    ctx->state[2] ^= C;
    ctx->state[3] ^= D;
    ctx->state[4] ^= E;
    ctx->state[5] ^= F;
    ctx->state[6] ^= G;
    ctx->state[7] ^= H;
}

static void _sm3_update(sm3_context* ctx, const unsigned char* input, int size)
{
    int fill;
    unsigned int left;

    if (size <= 0)
    {
        return;
    }

    left = ctx->total[0] & 0x3F;
    fill = 64 - left;

    ctx->total[0] += size;
    ctx->total[0] &= 0xFFFFFFFF;

    if (ctx->total[0] < (unsigned int)size)
    {
        ctx->total[1]++;
    }

    if (left && size >= fill)
    {
        memcpy((void *)(ctx->buffer + left), (void *)input, fill);
        _sm3_process(ctx, ctx->buffer);
        input += fill;
        size -= fill;
        left = 0;
    }

    while (size >= 64)
    {
        _sm3_process(ctx, input);
        input += 64;
        size  -= 64;
    }

    if (size > 0)
    {
        memcpy((void *)(ctx->buffer + left), (void *)input, size);
    }
}

static void _sm3_final(sm3_context* ctx, unsigned char output[32])
{
    unsigned int last, padn;
    unsigned int high, low;
    unsigned char msglen[8];

    high = (ctx->total[0] >> 29) | (ctx->total[1] << 3);
    low  = (ctx->total[0] << 3);

    PUT_ULONG_BE(high, msglen, 0);
    PUT_ULONG_BE(low,  msglen, 4);

    last = ctx->total[0] & 0x3F;
    padn = (last < 56) ? (56 - last) : (120 - last);

    _sm3_update(ctx, (unsigned char*)sm3_padding, padn);
    _sm3_update(ctx, msglen, 8);

    PUT_ULONG_BE(ctx->state[0], output, 0);
    PUT_ULONG_BE(ctx->state[1], output, 4);
    PUT_ULONG_BE(ctx->state[2], output, 8);
    PUT_ULONG_BE(ctx->state[3], output, 12);
    PUT_ULONG_BE(ctx->state[4], output, 16);
    PUT_ULONG_BE(ctx->state[5], output, 20);
    PUT_ULONG_BE(ctx->state[6], output, 24);
    PUT_ULONG_BE(ctx->state[7], output, 28);
}

int sm3_buffer(const unsigned char* buff, int size, unsigned char* out_buff, int out_size)
{
    sm3_context ctx;
    unsigned char output[32] = {0};
    int i = 0;

    _sm3_init(&ctx);
    _sm3_update(&ctx, buff, size);
    _sm3_final(&ctx, output);

    for (i = 0; i < 32; ++i)
    {
        if (2 * i >= out_size)
        {
            break;
        }

        sprintf((void *)out_buff + 2 * i, "%02x", output[i]);
    }
    return 0;
}

/*
static int _sm3_test()
{
    // input: abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd
    // output: debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732
    const char* test = "abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd";
    char buff[65] = {0};
    sm3_buffer(test, strlen(test), buff, 64);
    printf(buff);
}
*/
